{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 完整的训练流程\n",
    "1. 数据基于`https://github.com/hikariming/alpaca_chinese_dataset`\n",
    "2. 部分代码来源于`https://github.com/27182812/ChatGLM-chinese-insturct/blob/main/finetune.py`\n",
    "3. 基于我之前修改的`model_chatglm.py`做的一整套教程\n",
    "\n",
    "## 清洗数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 如果没有下载这个仓库，可以使用下面命令进行clone\n",
    "\n",
    "# !git clone https://github.com/hikariming/alpaca_chinese_dataset.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from glob import glob\n",
    "import os \n",
    "import pandas as pd \n",
    "import shutil\n",
    "from itertools import chain\n",
    "from tqdm import tqdm\n",
    "from pathlib import Path\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20,\n",
       " ['alpaca_chinese_dataset/其他中文问题补充\\\\三国问题.json',\n",
       "  'alpaca_chinese_dataset/其他中文问题补充\\\\企业管理问题.json',\n",
       "  'alpaca_chinese_dataset/其他中文问题补充\\\\传统诗词及文化常识问题.json',\n",
       "  'alpaca_chinese_dataset/其他中文问题补充\\\\党建类数据集.json',\n",
       "  'alpaca_chinese_dataset/其他中文问题补充\\\\其他问题.json'])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_dir_list = ['alpaca_chinese_dataset/其他中文问题补充/',\n",
    "                   'alpaca_chinese_dataset/翻译后的中文数据/',\n",
    "                   'alpaca_chinese_dataset/chatglm问题数据补充/',\n",
    "                #    'alpaca_chinese_dataset/原始英文数据/'\n",
    "                   ]\n",
    "\n",
    "all_json_path = [glob(i+\"*.json\") for i in target_dir_list]\n",
    "all_json_path = list(chain(*all_json_path))\n",
    "len(all_json_path), all_json_path[:5]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_json(x:str):\n",
    "    try:\n",
    "        data = pd.read_json(x)\n",
    "        return data \n",
    "    except Exception as e:\n",
    "        return pd.DataFrame()\n",
    "\n",
    "# Read every file, then merge the per-file frames into one table.\n",
    "frames = [read_json(path) for path in all_json_path]\n",
    "alldata = pd.concat(frames)\n",
    "# alldata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "genrate_data_dir = \"data3_0328\"\n",
    "genrate_data_dir = Path(genrate_data_dir)\n",
    "\n",
    "if genrate_data_dir.exists():\n",
    "    shutil.rmtree(genrate_data_dir, ignore_errors=True)\n",
    "\n",
    "os.makedirs(genrate_data_dir, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "28it [00:00, 397.31it/s]\n"
     ]
    }
   ],
   "source": [
    "# Shuffle all rows; the fixed seed makes the chunk contents reproducible.\n",
    "alldata = alldata.sample(frac=1, random_state=42).reset_index(drop=True)\n",
    "\n",
    "# Number of rows written to each CSV chunk.\n",
    "chunk_size = 666\n",
    "\n",
    "chunk_starts = range(0, alldata.shape[0], chunk_size)\n",
    "# total= lets tqdm draw a real progress bar; enumerate() alone hides the length.\n",
    "for index, start_id in tqdm(enumerate(chunk_starts), total=len(chunk_starts)):\n",
    "    temp_data = alldata.iloc[start_id:(start_id+chunk_size)]\n",
    "    temp_data.to_csv(genrate_data_dir.joinpath(f\"{index}.csv\"), index=False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from thuglm.modeling_chatglm import ChatGLMForConditionalGeneration\n",
    "# from thuglmcode.model_chatglm import ChatGLMForConditionalGeneration\n",
    "from transformers import Trainer, TrainingArguments\n",
    "import random\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from peft import get_peft_model, LoraConfig, TaskType\n",
    "from typing import Optional\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n"
     ]
    }
   ],
   "source": [
    "# trust_remote_code=True runs the custom tokenizer code shipped in the model repo.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    \"yuanzhoulvpi/chatglm6b-dddd\",\n",
    "    trust_remote_code=True,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.\n",
      "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ac25d881dd5e4d00b3d35248f0f2b81d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet\\lib\\site-packages\\peft\\tuners\\lora.py:173: UserWarning: fan_in_fan_out is set to True but the target module is not a Conv1D. Setting fan_in_fan_out to False.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Load the base checkpoint, then cast it to half precision and move it to the GPU.\n",
    "model = AutoModel.from_pretrained(\n",
    "    \"yuanzhoulvpi/chatglm6b-dddd\", trust_remote_code=True\n",
    ")\n",
    "model = model.half().cuda()\n",
    "\n",
    "# LoRA is attached to the 'query_key_value' modules only; other candidates\n",
    "# noted by the author: 'dense', 'dense_h_to_4h', 'dense_4h_to_h'.\n",
    "lora_options = dict(\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    "    inference_mode=False,\n",
    "    r=8,\n",
    "    lora_alpha=32,\n",
    "    lora_dropout=0.1,\n",
    "    target_modules=['query_key_value'],\n",
    ")\n",
    "peft_config = LoraConfig(**lora_options)\n",
    "model = get_peft_model(model, peft_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyTrainer(Trainer):\n",
    "    \"\"\"Trainer whose checkpoints contain only the trainable parameters.\"\"\"\n",
    "\n",
    "    def _save(self, output_dir: Optional[str] = None, state_dict=None):\n",
    "        \"\"\"Write the parameters with ``requires_grad=True`` to ``chatglm-lora.pt``.\n",
    "\n",
    "        Args:\n",
    "            output_dir: Target folder; defaults to ``self.args.output_dir``.\n",
    "            state_dict: Accepted for signature compatibility but not used.\n",
    "        \"\"\"\n",
    "        # If we are executing this function, we are the process zero, so we don't check for that.\n",
    "        output_dir = output_dir if output_dir is not None else self.args.output_dir\n",
    "        os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "        # detach() drops the autograd link so the file holds plain CPU tensors\n",
    "        # rather than graph-attached copies that still require grad.\n",
    "        saved_params = {\n",
    "            k: v.detach().cpu()\n",
    "            for k, v in self.model.named_parameters()\n",
    "            if v.requires_grad\n",
    "        }\n",
    "        torch.save(saved_params, os.path.join(output_dir, \"chatglm-lora.pt\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(21, 7)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "random.seed(42)\n",
    "\n",
    "# glob() returns files in arbitrary order; sorting first makes the seeded\n",
    "# sample below pick the same held-out files on every machine and run.\n",
    "all_file_list = sorted(glob(pathname=str(genrate_data_dir.joinpath(\"*.csv\"))))\n",
    "\n",
    "# Hold out 25% of the chunk files (rounded down) for validation.\n",
    "test_file_list = random.sample(all_file_list, int(len(all_file_list)*0.25))\n",
    "held_out = set(test_file_list)  # set membership instead of repeated list scans\n",
    "train_file_list = [i for i in all_file_list if i not in held_out]\n",
    "\n",
    "len(train_file_list), len(test_file_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ac088eb640b040c4b22bd1839a99078f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Resolving data files:   0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and preparing dataset csv/default to c:/Users/yuanz/PycharmProjects/zero_nlp/simple_thu_chatglm6b/cache_data/csv/default-94af94a7f1da5386/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a38376cc6ad745d5b3adff5eaa8906a6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5bf50389005c4897a4fd33e47eab2dc0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89c3fe519cde46d09519942d3686e142",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split: 0 examples [00:00, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2ada14f734db4f4e9309f0f1b28ec34c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating valid split: 0 examples [00:00, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset csv downloaded and prepared to c:/Users/yuanz/PycharmProjects/zero_nlp/simple_thu_chatglm6b/cache_data/csv/default-94af94a7f1da5386/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317. Subsequent calls will reuse this data.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b7daf9d998ea44d69dfa996a60507802",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['instruction', 'input', 'output'],\n",
       "        num_rows: 13986\n",
       "    })\n",
       "    valid: Dataset({\n",
       "        features: ['instruction', 'input', 'output'],\n",
       "        num_rows: 4185\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Map each split name to the CSV chunk files that belong to it.\n",
    "split_files = {'train': train_file_list, 'valid': test_file_list}\n",
    "\n",
    "dataset = load_dataset(\"csv\", data_files=split_files, cache_dir=\"cache_data\")\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_masks_and_position_ids(\n",
    "    seq, seq_len, context_length, device, gmask=False, position_encoding_2d=True\n",
    "):\n",
    "    mask_position = (\n",
    "        seq_len - 2\n",
    "    )  # is equal to `seq.index(mask_token)` or `seq.index(150001)`\n",
    "    attention_mask = torch.ones((1, context_length, context_length), device=device)\n",
    "    attention_mask.tril_()\n",
    "    attention_mask[..., : mask_position - 1] = 1\n",
    "    attention_mask = (attention_mask < 0.5).bool()\n",
    "\n",
    "    if position_encoding_2d:\n",
    "        seq_length = seq_len - 1  # is equal to `seq_length = seq.index(150004)`\n",
    "        position_ids = torch.arange(context_length, dtype=torch.long, device=device)\n",
    "        if not gmask:\n",
    "            position_ids[seq_length:] = mask_position\n",
    "        block_position_ids = torch.cat(\n",
    "            (\n",
    "                torch.zeros(seq_length, dtype=torch.long, device=device),\n",
    "                torch.arange(\n",
    "                    context_length - seq_length, dtype=torch.long, device=device\n",
    "                )\n",
    "                + 1,\n",
    "            )\n",
    "        )\n",
    "        position_ids = torch.stack((position_ids, block_position_ids), dim=0)\n",
    "    else:\n",
    "        position_ids = torch.arange(context_length, dtype=torch.long, device=device)\n",
    "        if not gmask:\n",
    "            position_ids[context_length - 1 :] = mask_position\n",
    "    return attention_mask, position_ids\n",
    "\n",
    "def data_collator(features: list) -> dict:\n",
    "    len_ids = [len(feature[\"input_ids\"]) for feature in features]\n",
    "    longest = max(len_ids) + 1\n",
    "    input_ids = []\n",
    "    attention_mask_list = []\n",
    "    position_ids_list = []\n",
    "    labels_list = []\n",
    "    for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):\n",
    "        ids = feature[\"input_ids\"]\n",
    "        seq_len = feature[\"seq_len\"]\n",
    "        labels = (\n",
    "            [-100] * (seq_len - 1)\n",
    "            + ids[(seq_len - 1) :]\n",
    "            + [tokenizer.eop_token_id]\n",
    "            + [-100] * (longest - ids_l - 1)\n",
    "        )\n",
    "        ids = ids + [tokenizer.eop_token_id] * (longest - ids_l)\n",
    "        _ids = torch.LongTensor(ids)\n",
    "        attention_mask, position_ids = get_masks_and_position_ids(\n",
    "            ids, seq_len, longest, _ids.device, gmask=False\n",
    "        )\n",
    "        labels_list.append(torch.LongTensor(labels))\n",
    "        input_ids.append(_ids)\n",
    "        attention_mask_list.append(attention_mask)\n",
    "        position_ids_list.append(position_ids)\n",
    "    input_ids = torch.stack(input_ids)\n",
    "    labels = torch.stack(labels_list)\n",
    "    attention_mask = torch.stack(attention_mask_list)\n",
    "    position_ids = torch.stack(position_ids_list)\n",
    "    return {\n",
    "        \"input_ids\": input_ids,\n",
    "        \"labels\": labels,\n",
    "        \"attention_mask\": attention_mask,\n",
    "        \"position_ids\": position_ids,\n",
    "    }\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00f459861ca94ef2b36e0e1e17b0cd08",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/13986 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ed1c91e259b44c9ea75c8c259a544acd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/4185 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e79e494e6f9d43e19a6dce330484e539",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Filter:   0%|          | 0/13986 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b5c9b01ee1d94586a431f5b01a091570",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Filter:   0%|          | 0/4185 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet\\lib\\site-packages\\dill\\_dill.py:1845: PicklingWarning: Cannot locate reference to <class 'google.protobuf.pyext._message.CMessage'>.\n",
      "  warnings.warn('Cannot locate reference to %r.' % (obj,), PicklingWarning)\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet\\lib\\site-packages\\dill\\_dill.py:1847: PicklingWarning: Cannot pickle <class 'google.protobuf.pyext._message.CMessage'>: google.protobuf.pyext._message.CMessage has recursive self-references that trigger a RecursionError.\n",
      "  warnings.warn('Cannot pickle %r: %s.%s has recursive self-references that trigger a RecursionError.' % (obj, obj.__module__, obj_name), PicklingWarning)\n",
      "Parameter 'function'=<function preprocess at 0x000002A58560CCA0> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ecac3272211d4f139716870ae7e9b7d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/13980 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a4960625e5ab4832b257db516224e8cb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/4182 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['context', 'target', 'input_ids', 'seq_len'],\n",
       "        num_rows: 13980\n",
       "    })\n",
       "    valid: Dataset({\n",
       "        features: ['context', 'target', 'input_ids', 'seq_len'],\n",
       "        num_rows: 4182\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def format_example(example: dict) -> dict:\n",
    "    context = f\"Instruction: {example['instruction']}\\n\"\n",
    "    if example.get(\"input\"):\n",
    "        context += f\"Input: {example['input']}\\n\"\n",
    "    context += \"Answer: \"\n",
    "    target = example[\"output\"]\n",
    "    # {\"context\": context, \"target\": target}\n",
    "    example['context'] = context\n",
    "    example['target'] = target\n",
    "    return example\n",
    "\n",
    "max_seq_length = 512\n",
    "\n",
    "def preprocess(example):\n",
    "    prompt = example[\"context\"]\n",
    "    target = example[\"target\"]\n",
    "    prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)\n",
    "    target_ids = tokenizer.encode(\n",
    "        target, max_length=max_seq_length, truncation=True, add_special_tokens=False\n",
    "    )\n",
    "    input_ids = prompt_ids + target_ids + [tokenizer.eos_token_id]\n",
    "    return {\"input_ids\": input_ids, \"seq_len\": len(prompt_ids)}\n",
    "\n",
    "def filter_nan(example):\n",
    "    return example['target'] is not None and example['context'] is not  None\n",
    "\n",
    "\n",
    "# Build the prompt/answer text and drop the original CSV columns.\n",
    "source_columns = dataset['train'].column_names\n",
    "formatted_datasets = dataset.map(function=format_example, remove_columns=source_columns)\n",
    "# Discard rows whose target or context is None.\n",
    "formatted_datasets = formatted_datasets.filter(function=filter_nan)\n",
    "\n",
    "# Tokenize every remaining row.\n",
    "tokenized_datasets = formatted_datasets.map(function=preprocess)\n",
    "tokenized_datasets\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.gradient_checkpointing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.gradient_checkpointing = True\n",
    "# model._set_gradient_checkpointing(value=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl\n",
    "\n",
    "\n",
    "\n",
    "class EmptyCacheCallBack(TrainerCallback):\n",
    "    \"\"\"Callback that works around running out of GPU memory.\n",
    "\n",
    "    It calls ``torch.cuda.empty_cache()`` after each log event and at the start\n",
    "    of every epoch and every step.\n",
    "    \"\"\"\n",
    "\n",
    "    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs, **kwargs):\n",
    "        \"\"\"Event called after logging the last logs.\"\"\"\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):\n",
    "        \"\"\"Release cached CUDA memory at the start of an epoch.\"\"\"\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):\n",
    "        \"\"\"Release cached CUDA memory at the start of a training step.\"\"\"\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "\n",
    "# Shared instance; pass it to the trainer through `callbacks=[eccb]`.\n",
    "eccb = EmptyCacheCallBack()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet\\lib\\site-packages\\transformers\\optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n",
      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='51' max='873' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 51/873 02:16 < 38:17, 0.36 it/s, Epoch 0.06/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>\n",
       "    <div>\n",
       "      \n",
       "      <progress value='502' max='4182' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 502/4182 00:41 < 05:06, 11.99 it/s]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "args = TrainingArguments(\n",
    "    output_dir=\"test004\",\n",
    "    per_device_train_batch_size=2, \n",
    "    per_device_eval_batch_size=1,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    eval_steps=50,\n",
    "    logging_steps=50,\n",
    "    gradient_accumulation_steps=8,\n",
    "    num_train_epochs=1,\n",
    "    weight_decay=0.1,\n",
    "    # warmup_steps=1_000 exceeded the 873 optimizer steps of this run, so the\n",
    "    # learning rate never reached its peak and the cosine decay never began;\n",
    "    # a ratio keeps the warmup proportional to the actual schedule length.\n",
    "    warmup_ratio=0.1,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    learning_rate=5e-4,\n",
    "    save_steps=100,\n",
    "    fp16=True,\n",
    "    push_to_hub=False,\n",
    "    # Keep 'seq_len' etc. in the batch rows; data_collator reads them.\n",
    "    remove_unused_columns=False\n",
    ")\n",
    "\n",
    "# MyTrainer saves only the trainable parameters at each checkpoint.\n",
    "trainer = MyTrainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    args=args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"valid\"],\n",
    "    # callbacks=[eccb]\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mynet3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
